/**
 * Copyright (c) 2011, Tim Buktu
 * All rights reserved.
 * <p>
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * <p>
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 * <p>
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

package com.example.emailWeb.ntru.util;

import com.example.emailWeb.ntru.exception.NtruException;

import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.math.BigInteger;
import java.util.Arrays;

/**
 * Converts a coefficient array to a compact byte array and vice versa.
 * 将系数数组转换为紧凑字节数组，反之亦然。
 */
public class ArrayEncoder implements Serializable {
    /**
     * Bit string to coefficient conversion table from P1363.1. Also found at
     * {@link http://stackoverflow.com/questions/1562548/how-to-make-a-message-into-a-polynomial}
     * <p/>
     * Convert each three-bit quantity to two ternary coefficients as follows, and concatenate the resulting
     * ternary quantities to obtain [the output].
     * <p/>
     * <code>
     * {0, 0, 0} -> {0, 0}<br/>
     * {0, 0, 1} -> {0, 1}<br/>
     * {0, 1, 0} -> {0, -1}<br/>
     * {0, 1, 1} -> {1, 0}<br/>
     * {1, 0, 0} -> {1, 1}<br/>
     * {1, 0, 1} -> {1, -1}<br/>
     * {1, 1, 0} -> {-1, 0}<br/>
     * {1, 1, 1} -> {-1, 1}<br/>
     左右两边的值是一一对应的，表示的值是0到7。
     左边是2进制，2^3 = 8；
     右边是3进制，3^2 = 9；右边能表示的范围比左边多一个，规定{-1, -1}不允许使用。3进制中的2用-1表示。
     * </code>
     */

    // -1代表2; 三进制,(1,0)->1*3+0 = 3; (-1,1)->2*3+1=7
    private static final int[] COEFF1_TABLE = {0, 0, 0, 1, 1, 1, -1, -1};//8个{int,int}所有的第一个数。
    private static final int[] COEFF2_TABLE = {0, 1, -1, 0, 1, -1, 0, 1};//8个{int,int}所有的第二个数。

    /**
     * Coefficient to bit string conversion table from P1363.1. Also found at
     * {@link http://stackoverflow.com/questions/1562548/how-to-make-a-message-into-a-polynomial}
     * <p/>
     * Convert each set of two ternary coefficients to three bits as follows, and concatenate the resulting bit
     * quantities to obtain [the output]:
     * <p/>
     * <code>
     * {-1, -1} -> set "fail" to 1 and set bit string to {1, 1, 1}
     * {-1, 0} -> {1, 1, 0}<br/>
     * {-1, 1} -> {1, 1, 1}<br/>
     * {0, -1} -> {0, 1, 0}<br/>
     * {0, 0} -> {0, 0, 0}<br/>
     * {0, 1} -> {0, 0, 1}<br/>
     * {1, -1} -> {1, 0, 1}<br/>
     * {1, 0} -> {0, 1, 1}<br/>
     * {1, 1} -> {1, 0, 0}<br/>
     * </code>
     */
    //下述数组长度为 9，只有8个对应的转换关系。3个BIT_TABLE数组的第一个数全部设为1。
    private static final int[] BIT1_TABLE = {1, 1, 1, 0, 0, 0, 1, 0, 1};
    private static final int[] BIT2_TABLE = {1, 1, 1, 1, 0, 0, 0, 1, 0};
    private static final int[] BIT3_TABLE = {1, 0, 1, 0, 0, 1, 1, 1, 0};
    private static final BigInteger THREE = BigInteger.valueOf(3);

    /**
     * Encodes an int array whose elements are between 0 and <code>q</code>,
     * to a byte array leaving no gaps between bits.<br/>
     * <code>q</code> must be a power of 2.
     * @param a the input array
     * @param q the modulus
     * @return the encoded array
     *
     * 示例：
        int[] a = {1,300,3,4};
        ArrayEncoder.encodeModQ(a,512) -> [1, 88, 14, 32, 0]
        512, 2^9 == 512
        {1,300,3,4} 每个int只留下9个bit
        1   -> 0 0000 0001
        300 -> 1 0010 1100
        3   -> 0 0000 0011
        4   -> 0 0000 0100

        从低位开始取9个bit合成int
        0000 0001 -> 1
        0101 1000 -> 88
        0000 1110 -> 14
        0010 0000 -> 32
        0000 0000 -> 0

        decodeModQ() 就是反过程，例子不再给出。
     */
    public static byte[] encodeModQ(int[] a, int q) {
        //Integer.numberOfLeadingZeros(q),
        // 返回的最高阶的（“最左边的”）中所指定的二进制补码表示的一个位前述零个比特的数量 int值。
        int bitsPerCoeff = 31 - Integer.numberOfLeadingZeros(q);//q的实际有效位的位数。
        int numBits = a.length * bitsPerCoeff;
        int numBytes = (numBits + 7) / 8;//通过bit数，求byte数。相当于除8后，向上取整。
        byte[] data = new byte[numBytes];
        int bitIndex = 0;
        int byteIndex = 0;

        for (int i = 0; i < a.length; i++) {
            //每个int
            String s = Integer.toBinaryString(a[i]);
            for (int j = 0; j < bitsPerCoeff; j++) {
                //每个int值，占用的最大bit数为bitsPerCoeff。
                // 重要：每个int只采用固定长度(bitsPerCoeff)的bit数，不是8的倍数，在下一个for循环中下一个int的bits会补上。
                int currentBit = (a[i] >> j) & 1;
                data[byteIndex] |= currentBit << bitIndex;//|=：两个二进制对应位都为0时，结果等于0，否则结果等于1；
                if (bitIndex == 7) {
                    bitIndex = 0;
                    byteIndex++;
                } else
                    bitIndex++;
            }
        }
        return data;
    }

    /**
     * Like {@link #encodeModQ(int[], int)} but only returns the first <code>numBytes</code>
     * 截断，取前numBytes个字节，返回。
     * bytes of the encoding.
     * @param a the input array
     * @param q the modulus
     * @param numBytes
     * @return the encoded array
     */
    public static byte[] encodeModQTrunc(int[] a, int q, int numBytes) {
        int bitsPerCoeff = 31 - Integer.numberOfLeadingZeros(q);
        byte[] data = new byte[numBytes];
        int bitIndex = 0;
        int byteIndex = 0;
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < bitsPerCoeff; j++) {
                int currentBit = (a[i] >> j) & 1;
                data[byteIndex] |= currentBit << bitIndex;
                if (bitIndex == 7) {
                    bitIndex = 0;
                    byteIndex++;
                    if (byteIndex >= numBytes)
                        return data;
                } else
                    bitIndex++;
            }
        }
        return null;
    }

    /**
     * Decodes a <code>byte</code> array encoded with {@link #encodeModQ(int[], int)} back to an <code>int</code> array.<br/>
     * <code>N</code> is the number of coefficients. <code>q</code> must be a power of <code>2</code>.<br/>
     * Ignores any excess bytes.
     * @param data an encoded ternary polynomial
     * @param N number of coefficients
     *          我的观点：解码需要传递数据的长度，如果解码由解密人实现的时候，解密人得有这个N
     * @param q 模q，最大的数不超过q
     * @return an array containing <code>N</code> coefficients between <code>0</code> and <code>q-1</code>
     */
    public static int[] decodeModQ(byte[] data, int N, int q) {
        int[] coeffs = new int[N];
        int bitsPerCoeff = 31 - Integer.numberOfLeadingZeros(q);
        int mask = -1 >>> (32 - bitsPerCoeff);   // for truncating values to bitsPerCoeff bits
        int byteIndex = 0;
        int bitIndex = 0;   // next bit in data[byteIndex]
        int coeffBuf = 0;   // contains (bitIndex) bits
        int coeffBits = 0;   // length of coeffBuf
        int coeffIndex = 0;   // index into coeffs
        while (coeffIndex < N) {
            // copy bitsPerCoeff or more into coeffBuf
            // 当初encodeModQ(int[] a, int q) 编码时，并不是按照一个int，4个byte，添加进byte数组的。
            // 最大不超过q，若q的高位有一个0 byte，那么就可以少用一个byte。
            // 将固定长度的byte合成一个int。
            while (coeffBits < bitsPerCoeff) {
                coeffBuf += (data[byteIndex] & 0xFF) << coeffBits;
                coeffBits += 8 - bitIndex; // ?
                byteIndex++;
                bitIndex = 0;
            }

            // low bitsPerCoeff bits = next coefficient
            coeffs[coeffIndex] = coeffBuf & mask;// 将int加入数组
            coeffIndex++;

            coeffBuf >>>= bitsPerCoeff; //重置coeffBuf，清0
            coeffBits -= bitsPerCoeff;
        }
        return coeffs;
    }

    /**
     * Decodes data encoded with {@link #encodeModQ(int[], int)} back to an <code>int</code> array.<br/>
     * <code>N</code> is the number of coefficients. <code>q</code> must be a power of <code>2</code>.<br/>
     * Ignores any excess bytes.
     * @param is an encoded ternary polynomial
     * @param N number of coefficients
     * @param q
     * @return the decoded polynomial
     */
    public static int[] decodeModQ(InputStream is, int N, int q) throws IOException {
        int qBits = 31 - Integer.numberOfLeadingZeros(q);
        int size = (N * qBits + 7) / 8;
        byte[] arr = ArrayEncoder.readFullLength(is, size);
        return decodeModQ(arr, N, q);
    }

    /**
     * Decodes a <code>byte</code> array encoded with {@link #encodeMod3Sves(int[], boolean)} back to an <code>int</code> array
     * with <code>N</code> coefficients between <code>-1</code> and <code>1</code>.<br/>
     * Ignores any excess bytes.<br/>
     * See P1363.1 section 9.2.2.
     * 将字节数组 转成 int[]存放的三元多项式序列。
     * @param data an encoded ternary polynomial 输入示例:[-19, 82, -84, -4, -11, 93, 85, -94, 53, -84,
     * @param N number of coefficients;  N是data长度的16/3倍。
     * @param skipFirst whether to leave the constant coefficient zero and start populating at the linear coefficient
     * @return the decoded coefficients
     */
    public static int[] decodeMod3Sves(byte[] data, int N, boolean skipFirst) {
        int[] coeffs = new int[N];
        int coeffIndex = skipFirst ? 1 : 0;
        int i = 0;
        while (i < data.length / 3 * 3 && coeffIndex < N - 1) {
            // process 24 bits at a time in the outer loop
            // 问：为何不是4个byte一起算
            //若前3个byte是 113, 96, 21
            //chunk是 (bin)000101010110000001110001 -> (int)1400945
            int chunk = (data[i++] & 0xFF) | ((data[i++] & 0xFF) << 8) | ((data[i++] & 0xFF) << 16); // 3个byte结合起来的int值
//            System.out.println("chunnk:"+chunk);
            for (int j = 0; j < 8 && coeffIndex < N - 1; j++) {
                // process 3 bits at a time in the inner loop
                int coeffTableIndex = ((chunk & 1) << 2) + (chunk & 2) + ((chunk & 4) >> 2);   // low 3 bits in reverse order
                coeffs[coeffIndex++] = COEFF1_TABLE[coeffTableIndex];
                coeffs[coeffIndex++] = COEFF2_TABLE[coeffTableIndex];
                chunk >>= 3; // 去除低3位
            }
        }
        return coeffs;
    }

    /**
     * Encodes an <code>int</code> array whose elements are between <code>-1</code> and <code>1</code>, to a byte array.
     * <code>coeffs[2*i]</code> and <code>coeffs[2*i+1]</code> must not both equal -1 for any integer </code>i<code>,
     * so this method is only safe to use with arrays produced by {@link #decodeMod3Sves(byte[], int, boolean)}.<br/>
     * See P1363.1 section 9.2.3.
     * 将int[]存放的三元多项式序列，转成字节数组；
     * @param arr 输入示例：[1, -1, 1, -1, -1, 0, 1, 1, 1, -1, 0, 0...]
     * @param skipFirst whether to skip the constant coefficient
     * @return the encoded array
     * @throws NtruException if <code>(-1,-1)</code> is encountered
     */
    public static byte[] encodeMod3Sves(int[] arr, boolean skipFirst) {
        int numBits = (arr.length * 3 + 1) / 2; // eg:{-1, 0} -> {1, 1, 0};1.5倍
        int numBytes = (numBits + 7) / 8;
        byte[] data = new byte[numBytes];
        int byteIndex = 0;
        int start = skipFirst ? 1 : 0;
        int end = skipFirst ? (arr.length - 1) | 1 : arr.length / 2 * 2;   // if there is an odd number of coeffs, throw away the highest one
        int i = start;
        while (i < end) {
            // process 24 bits at a time in the outer loop
            int chunk = 0;
            int chunkBits = 0;   // #bits in the chunk
            while (chunkBits < 24 && i < end) {
                int coeff1 = arr[i++] + 1;
                int coeff2 = arr[i++] + 1;
                if (coeff1 == 0 && coeff2 == 0)
                    //{-1, -1} -> set "fail"
                    throw new NtruException("Illegal encoding!");

                int bitTableIndex = coeff1 * 3 + coeff2;
                chunk |= BIT1_TABLE[bitTableIndex] << chunkBits++;
                chunk |= BIT2_TABLE[bitTableIndex] << chunkBits++;
                chunk |= BIT3_TABLE[bitTableIndex] << chunkBits++;
            }
            data[byteIndex++] = (byte) (chunk & 0xFF);
            if (byteIndex < data.length)
                data[byteIndex++] = (byte) ((chunk >> 8) & 0xFF);
            if (byteIndex < data.length)
                data[byteIndex++] = (byte) ((chunk >> 16) & 0xFF);
        }
        return data;
    }

    /**
     * Encodes an <code>int</code> array whose elements are between <code>-1</code> and <code>1</code>, to a byte array.
     * @return the encoded array
     */
    public static byte[] encodeMod3Tight(int[] intArray) {
        BigInteger sum = BigInteger.ZERO;
        for (int i = intArray.length - 1; i >= 0; i--) {
            sum = sum.multiply(THREE);
            sum = sum.add(BigInteger.valueOf(intArray[i] + 1));
        }

        int size = (THREE.pow(intArray.length).bitLength() + 7) / 8;
        byte[] arr = sum.toByteArray();

        if (arr.length < size) {
            // pad with leading zeros so arr.length==size
            byte[] arr2 = new byte[size];
            System.arraycopy(arr, 0, arr2, size - arr.length, arr.length);
            return arr2;
        }

        if (arr.length > size)
            // drop sign bit
            arr = Arrays.copyOfRange(arr, 1, arr.length);
        return arr;
    }

    /**
     * Converts a byte array produced by {@link #encodeMod3Tight(int[])} back to an <code>int</code> array.
     * @param b a byte array
     * @param N number of coefficients
     * @return the decoded array
     */
    public static int[] decodeMod3Tight(byte[] b, int N) {
        System.out.println("b:"+Arrays.toString(b));
        BigInteger sum = new BigInteger(1, b);
        System.out.println("sum:"+sum);
        int[] coeffs = new int[N];
        for (int i = 0; i < N; i++) {
            coeffs[i] = sum.mod(THREE).intValue() - 1;
            if (coeffs[i] > 1)
                coeffs[i] -= 3;
            sum = sum.divide(THREE);
        }
        return coeffs;
    }

    /**
     * Converts data produced by {@link #encodeMod3Tight(int[])} back to an <code>int</code> array.
     * @param is an input stream containing the data to decode
     * @param N number of coefficients
     * @return the decoded array
     */
    public static int[] decodeMod3Tight(InputStream is, int N) throws IOException {
        int size = (int) Math.ceil(N * Math.log(3) / Math.log(2) / 8);
        byte[] arr = ArrayEncoder.readFullLength(is, size);
        return decodeMod3Tight(arr, N);
    }

    /**
     * Reads a given number of bytes from an <code>InputStream</code>.
     * If there are not enough bytes in the stream, an <code>IOException</code>
     * is thrown.
     * @param is
     * @param length
     * @return an array of length <code>length</code>
     * @throws IOException
     */
    public static byte[] readFullLength(InputStream is, int length) throws IOException {
        byte[] arr = new byte[length];
        if (is.read(arr) != arr.length)
            throw new IOException("Not enough bytes to read.");
        return arr;
    }

    /** Concatenates two or more <code>byte</code> arrays. */
    public static byte[] concatenate(byte[]... arrays) {
        int totalLength = 0;
        for (byte[] arr : arrays)
            totalLength += arr.length;
        byte[] bigArr = new byte[totalLength];

        int idx = 0;
        for (byte[] arr : arrays) {
            System.arraycopy(arr, 0, bigArr, idx, arr.length);
            idx += arr.length;
        }

        return bigArr;
    }

    /**
     * Copies the lower two bytes of an <code>int</code> into a <code>byte</code> array.
     * @param i
     * @return a byte array
     */
    public static byte[] toByteArray(int i) {
        byte[] arr = new byte[2];
        arr[0] = (byte) (i >>> 8);
        arr[1] = (byte) (i & 0xFF);
//        arr[1] = (byte)(i); // 我的观点:我觉得这种写法也是可以的。
        return arr;
    }
}